from dataset.dataset_2D import Dataset_2D
from dataset.dataset_3D import Dataset_3D

def get_dataset(args):
    """Build the pair of datasets selected by ``args``.

    ``args.dataset`` chooses the dataset class: ``'languagetable'`` uses
    ``Dataset_2D``; ``'rt1'``, ``'droid'`` and ``'bridge'`` use ``Dataset_3D``.

    Returns:
        A 2-tuple of datasets built with the chosen class:

        * ``args.do_evaluate`` truthy: ``(None, cls(args, mode=args.mode))``.
        * otherwise, ``args.debug`` truthy: two separate ``mode='val'``
          instances.
        * otherwise: a ``mode='train'`` instance and a ``mode='val'``
          instance.

    Raises:
        NotImplementedError: if ``args.dataset`` is not one of the names
            listed above; the offending name is the exception argument.
    """
    name = args.dataset

    # Resolve the class first; both families share the same mode logic below.
    if name == 'languagetable':
        dataset_cls = Dataset_2D
    elif name in ('rt1', 'droid', 'bridge'):
        dataset_cls = Dataset_3D
    else:
        raise NotImplementedError(name)

    # Evaluation builds a single dataset; its mode comes from args.mode.
    if args.do_evaluate:
        return None, dataset_cls(args, mode=args.mode)

    # In debug runs the first slot is also built with mode='val'.
    first_mode = 'val' if args.debug else 'train'
    return dataset_cls(args, mode=first_mode), dataset_cls(args, mode='val')